import torch

from .tucker_conv_fixed import Conv2d_tucker_fixed
from .tucker_conv_adaptive import Conv2d_tucker_adaptive


def Conv2d_tucker(adaptive: bool, in_channels: int, out_channels: int, kernel_size, stride, padding, dilation,
                  groups: int = 1, bias: bool = True, padding_mode: str = "zeros", dtype=None, device=None,
                  low_rank_percent=None, tau: float = 0.01
                  ):
    """Build a Tucker-format 2D convolution layer, adaptive or fixed.

    This is a thin factory: it picks ``Conv2d_tucker_adaptive`` when
    ``adaptive`` is truthy and ``Conv2d_tucker_fixed`` otherwise, and forwards
    the arguments to the chosen constructor.

    Args:
        adaptive: Selects the adaptive implementation when truthy, the fixed
            one otherwise.
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode, dtype, device: Forwarded positionally and
            unchanged to the selected class (presumably with the same meaning
            as in ``torch.nn.Conv2d`` -- confirm against the layer classes).
        low_rank_percent: Forwarded to the selected class. For the fixed layer,
            ``None`` triggers the legacy fallback to ``tau`` (see below).
        tau: Forwarded as the ``tau`` keyword to the adaptive layer. For the
            fixed layer it is only used as a fallback when ``low_rank_percent``
            is ``None``.

    Returns:
        The newly constructed ``Conv2d_tucker_adaptive`` or
        ``Conv2d_tucker_fixed`` instance.
    """
    if adaptive:
        return Conv2d_tucker_adaptive(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
                                      padding_mode, dtype, device, low_rank_percent, tau=tau,
                                      )

    # Historically this branch always passed ``tau`` as ``low_rank_percent``,
    # silently discarding the caller's ``low_rank_percent``. Honor an explicit
    # value, and keep the old behavior only when none was supplied so existing
    # callers that relied on ``tau`` are unaffected.
    fixed_rank_percent = tau if low_rank_percent is None else low_rank_percent
    return Conv2d_tucker_fixed(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                               bias, padding_mode, dtype, device, low_rank_percent=fixed_rank_percent,
                               )
